"""GPUStack v0.7.0

Revision ID: cbbc03c88985
Revises: c45e397531d1
Create Date: 2025-06-09 15:07:05.299418

"""
import json
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
import sqlmodel
import gpustack


# Revision identifiers, used by Alembic to place this migration in the
# revision graph: this revision (cbbc03c88985) applies directly on top of
# c45e397531d1 and declares no branch labels or extra dependencies.
revision: str = 'cbbc03c88985'
down_revision: Union[str, None] = 'c45e397531d1'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Apply the GPUStack v0.7.0 data and schema changes.

    Steps, in order:
      1. Rewrite legacy ``model_instances.distributed_servers`` payloads.
      2. Add ``workers.worker_uuid`` (NOT NULL, server default ``''``).
      3. Add nullable JSON columns ``model_instances.ports`` and
         ``model_instances.gpu_addresses``.
      4. Index ``workers.worker_uuid`` (non-unique).
    """
    # Data rewrite first; it only touches model_instances.distributed_servers.
    migrate_model_instance_distributed_servers()

    with op.batch_alter_table('workers', schema=None) as workers_batch:
        # The empty-string server default lets pre-existing rows satisfy NOT NULL.
        uuid_column = sa.Column(
            'worker_uuid', sa.String(length=255), nullable=False, server_default=''
        )
        workers_batch.add_column(uuid_column)

    with op.batch_alter_table('model_instances') as instances_batch:
        for json_column_name in ('ports', 'gpu_addresses'):
            instances_batch.add_column(
                sa.Column(json_column_name, sa.JSON(), nullable=True)
            )

    op.create_index(
        'ix_workers_worker_uuid', 'workers', ['worker_uuid'], unique=False
    )


def downgrade() -> None:
    """Revert the GPUStack v0.7.0 schema changes.

    Undoes upgrade() in reverse order: drops the ``ix_workers_worker_uuid``
    index, then the ``model_instances.ports`` / ``gpu_addresses`` columns,
    then ``workers.worker_uuid``.

    NOTE(review): the ``distributed_servers`` data rewrite performed during
    upgrade is not reversed here, so rows keep the subordinate-worker format.
    """
    # Drop the index BEFORE the column it covers. On PostgreSQL and MySQL,
    # dropping the column implicitly removes its single-column index, so a
    # DROP INDEX issued afterwards fails and aborts the downgrade.
    op.drop_index('ix_workers_worker_uuid', table_name='workers')

    with op.batch_alter_table('model_instances') as batch_op:
        batch_op.drop_column('ports')
        batch_op.drop_column('gpu_addresses')
    with op.batch_alter_table('workers', schema=None) as batch_op:
        batch_op.drop_column('worker_uuid')


def _to_subordinate_worker(entry):
    """Return a copy of a legacy rpc_server/ray_actor entry as a subordinate worker.

    The legacy single ``gpu_index`` key becomes a one-element ``gpu_indexes``
    list; every other key is copied unchanged. The input is not mutated.
    """
    worker = dict(entry)
    if "gpu_index" in worker:
        worker["gpu_indexes"] = [worker.pop("gpu_index")]
    return worker


def _convert_distributed_servers(dist_data):
    """Convert one legacy ``distributed_servers`` value to the new payload.

    Args:
        dist_data: the column value as fetched, either already decoded or a
            JSON-encoded string (the form depends on the DB driver).

    Returns:
        The new ``{"mode", "download_model_files", "subordinate_workers"}``
        dict, or None when the row should be left untouched: undecodable
        JSON, a value that is not a JSON object (including JSON ``null``),
        or no ``rpc_servers`` / ``ray_actors`` to convert.
    """
    if isinstance(dist_data, str):
        try:
            dist_data = json.loads(dist_data)
        except json.JSONDecodeError:
            # Best effort: leave unparseable rows as they are.
            return None

    # Only a JSON object can carry the legacy keys; calling .get() on a
    # list/number/None would crash the whole migration.
    if not isinstance(dist_data, dict):
        return None

    rpc_servers = dist_data.get("rpc_servers") or []
    ray_actors = dist_data.get("ray_actors") or []
    if not rpc_servers and not ray_actors:
        return None

    return {
        "mode": "delegated",
        # Set only when Ray actors were present.
        "download_model_files": bool(ray_actors),
        "subordinate_workers": [
            _to_subordinate_worker(entry) for entry in rpc_servers + ray_actors
        ],
    }


def migrate_model_instance_distributed_servers():
    """Rewrite legacy ``model_instances.distributed_servers`` payloads in place.

    Rows whose payload holds ``rpc_servers`` and/or ``ray_actors`` get a
    ``{"mode": "delegated", "download_model_files": ..., "subordinate_workers": [...]}``
    object instead. All other rows, including malformed payloads, are skipped.
    """
    conn = op.get_bind()
    rows = conn.execute(
        sa.text("SELECT id, distributed_servers FROM model_instances WHERE distributed_servers IS NOT NULL")
    ).fetchall()

    # Parameterized statement, built once and reused for every row.
    update_stmt = sa.text(
        "UPDATE model_instances SET distributed_servers = :dist WHERE id = :id"
    )
    for instance_id, dist_data in rows:
        new_dist_data = _convert_distributed_servers(dist_data)
        if new_dist_data is None:
            continue
        conn.execute(
            update_stmt,
            {"dist": json.dumps(new_dist_data), "id": instance_id},
        )
